if __name__ == '__main__':
    # HACK: only when this file is executed directly as a script, prepend the
    # directory two levels up to the import search path before the project
    # imports below run (presumably so the `envs` package can be found).
    # NOTE: '../../' is a relative path, so it is resolved against the current
    # working directory rather than against this file's location.
    import sys
    sys.path.insert(0, '../../')

import random
import gym
from envs.minigrid.adversarial import *
from envs.minigrid.kitchen import KitchenEnv
from envs.minigrid.traffic import TrafficEnv

class MinigridEnv(gym.Env):
    """
    A simple wrapper for a gym-minigrid environment. This implements propositions on top of the minigrid.

    The wrapper exposes only the 'image' entry of the wrapped environment's
    observations, enforces a per-episode step limit, and forwards event /
    reward-machine queries to the wrapped environment.
    """

    def __init__(self, env, letters, timeout = 100):
        """
        env:     the wrapped MiniGrid environment. Its observation space and its
                 observations must be indexable by 'image'.
        letters: iterable of proposition symbols (e.g. the string 'abc').
                 Duplicates are removed and the result is kept sorted.
        timeout: number of steps after which an episode is forced to end.
        """
        self.env = env
        self.letters = letters
        self.letter_types = sorted(set(letters))
        self.action_space = gym.spaces.Discrete(3) # Only use left, right, and forward actions
        self.observation_space = env.observation_space['image']
        self.num_episodes = 0
        self.time = 0
        self.timeout = timeout

    def step(self, action):
        """
        Forward `action` to the wrapped environment and advance the step counter.

        Returns (image observation, reward, done, info). `done` is forced to
        True once the number of steps taken since the last reset reaches
        `self.timeout`, regardless of what the wrapped environment reported.
        """
        obs, reward, done, info = self.env.step(action)
        self.time += 1
        if self.time >= self.timeout:
            done = True
        return obs['image'], reward, done, info

    def seed(self, seed=None):
        """
        Seed both the global `random` module generator and the wrapped environment.
        """
        random.seed(seed)
        self.env.seed(seed)

    def reset(self):
        """
        This function resets the world and collects the first observation.

        It also counts the new episode and restarts the timeout counter.
        Returns the 'image' entry of the wrapped environment's first observation.
        """
        self.num_episodes += 1
        self.time = 0
        return self.env.reset()['image']

    def render(self):
        """
        Ask the wrapped environment to render itself.
        """
        self.env.render()

    def get_events(self):
        """
        Return whatever the wrapped environment's `get_events()` reports.
        """
        return self.env.get_events()

    def get_propositions(self):
        """
        Return the sorted list of distinct proposition symbols.
        """
        return self.letter_types

    def _call_optional_env_hook(self, name):
        """
        Call the zero-argument method `name` on the wrapped environment if it exists.

        Returns the hook's result, or None when the wrapped environment has no
        such attribute. Only the attribute lookup is guarded: an AttributeError
        raised *inside* the hook is a real bug in the environment and is allowed
        to propagate instead of being mistaken for "hook not implemented".
        """
        try:
            hook = getattr(self.env, name)
        except AttributeError:
            return None
        return hook()

    def get_sync_rm_func(self):
        """
        Return the wrapped environment's `get_sync_rm_func()` result, or None
        if the wrapped environment does not define that method.
        """
        return self._call_optional_env_hook('get_sync_rm_func')

    def get_sync_rm_belief_func(self):
        """
        Return the wrapped environment's `get_sync_rm_belief_func()` result, or
        None if the wrapped environment does not define that method.
        """
        return self._call_optional_env_hook('get_sync_rm_belief_func')

class AdversarialMinigridEnv(MinigridEnv):
    """
    MinigridEnv around an `AdversarialEnv9x9`, with propositions 'abc' and a
    100-step episode limit.
    """

    def __init__(self):
        super().__init__(
            env=AdversarialEnv9x9(),
            letters='abc',
            timeout=100,
        )

class KitchenMinigridEnv(MinigridEnv):
    """
    MinigridEnv around a default-constructed `KitchenEnv`, with propositions
    'abcd' and a 400-step episode limit.
    """

    def __init__(self):
        super().__init__(
            env=KitchenEnv(),
            letters='abcd',
            timeout=400,
        )

class UnlockedKitchenMinigridEnv(MinigridEnv):
    """
    MinigridEnv around a `KitchenEnv` built with `is_locked=False`, with
    propositions 'abcd' and a 400-step episode limit.
    """

    def __init__(self):
        super().__init__(
            env=KitchenEnv(is_locked=False),
            letters='abcd',
            timeout=400,
        )

class RandomKitchenMinigridEnv(MinigridEnv):
    """
    MinigridEnv around a `KitchenEnv` built with `randomize_chores=True`, with
    propositions 'abcd' and a 400-step episode limit.
    """

    def __init__(self):
        super().__init__(
            env=KitchenEnv(randomize_chores=True),
            letters='abcd',
            timeout=400,
        )

class TrafficMinigridEnv(MinigridEnv):
    """
    MinigridEnv around a `TrafficEnv`, with propositions 'cdt' and a 100-step
    episode limit. Unlike the other wrappers, it exposes the wrapped
    environment's own action space.
    """

    def __init__(self):
        traffic_env = TrafficEnv()
        super().__init__(env=traffic_env, letters='cdt', timeout=100)
        # The base constructor installs a 3-action space; replace it with the
        # action space declared by the traffic environment itself.
        self.action_space = traffic_env.action_space

if __name__ == '__main__':
    # When run directly as a script, just construct one adversarial
    # environment and discard it; nothing is stepped or rendered.
    AdversarialMinigridEnv()

